Maximum sum BST in binary tree [DFS+Stack,DFS+Recursion]

Time: O(N); Space: O(H), where N is the number of nodes and H is the tree height; Difficulty: hard

Given the root of a binary tree, return the maximum sum of all keys of any subtree that is also a Binary Search Tree (BST).

Assume a BST is defined as follows:

  • The left subtree of a node contains only nodes with keys less than the node’s key.

  • The right subtree of a node contains only nodes with keys greater than the node’s key.

  • Both the left and right subtrees must also be binary search trees.
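The three conditions above can be checked in a single top-down pass by handing each node the open interval its key must fall in. A minimal sketch, assuming a hypothetical `TreeNode` that takes its children as constructor arguments (unlike the `TreeNode` defined later in this notebook) and a helper `is_bst` of our own. The strict `<` comparisons reject duplicate keys, matching "less than" and "greater than" in the definition.

```python
from typing import Optional


class TreeNode:
    """Hypothetical node type that takes its children as constructor arguments."""

    def __init__(self, val: int, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def is_bst(node: Optional[TreeNode],
           low: float = float("-inf"),
           high: float = float("inf")) -> bool:
    """Return True if the subtree rooted at `node` satisfies the BST definition.

    Every key must lie strictly between `low` and `high`, the bounds
    inherited from its ancestors, so duplicate keys are rejected.
    """
    if node is None:
        return True  # the empty tree counts as a BST
    if not (low < node.val < high):
        return False
    # Keys in the left subtree must stay below node.val, right ones above it.
    return (is_bst(node.left, low, node.val)
            and is_bst(node.right, node.val, high))


# Example 4's tree is a BST.
print(is_bst(TreeNode(2, TreeNode(1), TreeNode(3))))  # True
# Example 2's subtree rooted at 3 is not: its right child 2 is not above 3.
print(is_bst(TreeNode(3, TreeNode(1), TreeNode(2))))  # False
```

Passing bounds down, rather than only comparing a node with its direct children, is what catches violations deeper in the tree, such as a key in the left subtree that exceeds the root.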

Example 1:

Input: root = [1,4,3,2,4,2,5,null,null,null,null,null,null,4,6]

Output: 20

Explanation:

  • The maximum sum in a valid BST comes from the subtree rooted at the node with key 3: 3 + 2 + 5 + 4 + 6 = 20.

Example 2:

Input: root = [4,3,null,1,2]

Output: 2

Explanation: The maximum sum in a valid BST comes from the single-node subtree with key 2. The subtree rooted at 3 is not a BST because its right child 2 is smaller than 3.

Example 3:

Input: root = [-4,-2,-5]

Output: 0

Explanation:

  • All values are negative, so the best choice is the empty BST, whose sum is 0.

Example 4:

Input: root = [2,1,3]

Output: 6

Example 5:

Input: root = [5,4,8,3,null,6,3]

Output: 7
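The Input lines list keys in level order, with `null` marking a missing child; children are listed only for nodes that exist, which appears to match the serialization commonly used on LeetCode. A hypothetical `build_tree` helper (not part of the original notebook) that decodes such a list, written with Python's `None` in place of `null`:

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    def __init__(self, val: int) -> None:
        self.val = val
        self.left: Optional["TreeNode"] = None
        self.right: Optional["TreeNode"] = None


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Decode a level-order list where None marks a missing child.

    Children are listed only for nodes that exist, so a None entry
    does not reserve slots for its own (absent) children.
    """
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])  # nodes still waiting for their children
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # The next value is the left child, the one after it the right child.
        if values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


# Example 5: [5,4,8,3,null,6,3]
root = build_tree([5, 4, 8, 3, None, 6, 3])
print(root.left.left.val, root.left.right, root.right.right.val)  # 3 None 3
```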

Constraints:

  • Each tree has at most 40000 nodes.

  • Each node’s value lies in the range [-4 * 10^4, 4 * 10^4].
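With up to 40000 nodes, testing every subtree separately is likely too slow in the worst case: each check is linear in the subtree size, giving O(N·H) overall and O(N²) for a skewed tree. Such a brute force is still handy as a reference for small tests. A sketch with hypothetical helper names (`is_bst`, `subtree_sum`, `max_sum_bst_naive`); the helpers recurse, so it is only meant for small inputs.

```python
class TreeNode:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def is_bst(node, low=float("-inf"), high=float("inf")) -> bool:
    # Strict bounds inherited from ancestors, as in the BST definition.
    if node is None:
        return True
    return (low < node.val < high
            and is_bst(node.left, low, node.val)
            and is_bst(node.right, node.val, high))


def subtree_sum(node) -> int:
    if node is None:
        return 0
    return node.val + subtree_sum(node.left) + subtree_sum(node.right)


def max_sum_bst_naive(root) -> int:
    """Test every subtree independently: O(N * H) time, O(N^2) if skewed."""
    best = 0  # the empty BST contributes 0 (see Example 3)
    stack = [root]
    while stack:
        node = stack.pop()
        if node is None:
            continue
        if is_bst(node):
            best = max(best, subtree_sum(node))
        stack.extend((node.left, node.right))
    return best


# Example 5: [5,4,8,3,null,6,3]
root = TreeNode(5, TreeNode(4, TreeNode(3)), TreeNode(8, TreeNode(6), TreeNode(3)))
print(max_sum_bst_naive(root))  # 7
```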

Hints:

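  1. Create a data structure with 4 parameters: (sum, isBST, maxLeft, minLeft), i.e. the subtree's key sum, whether it is a BST, and the key bounds used to compare it against its parent.

  2. At each node, compute these parameters from the children's parameters, following the conditions of a Binary Search Tree.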
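One way to realize the hint's 4-parameter record, sketched with a hypothetical `Info` named tuple and `combine` function (the field names are our own, not the hint's). The empty subtree gets sentinel bounds that never reject a parent, and a node is a BST exactly when both children are BSTs and its key lies strictly between the left subtree's maximum and the right subtree's minimum. Each record is computed in O(1) from the children's records, which is what makes a single post-order pass O(N).

```python
import math
from typing import NamedTuple


class Info(NamedTuple):
    """Summary of one subtree (illustrative field names)."""
    is_bst: bool
    total: int    # sum of keys; only meaningful when is_bst is True
    lo: float     # smallest key in the subtree
    hi: float     # largest key in the subtree


# The empty subtree is a BST whose bounds never block the parent:
# lo=+inf always passes "val < right.lo", hi=-inf always passes "left.hi < val".
EMPTY = Info(True, 0, math.inf, -math.inf)
# Once a subtree is invalid its bounds are never read, so zeros suffice.
INVALID = Info(False, 0, 0, 0)


def combine(left: Info, val: int, right: Info) -> Info:
    """Build a node's Info from its children's Info in O(1)."""
    if left.is_bst and right.is_bst and left.hi < val < right.lo:
        return Info(True, left.total + val + right.total,
                    min(left.lo, val), max(val, right.hi))
    return INVALID


# Subtree rooted at 5 in Example 1, with leaves 4 and 6.
leaf4 = combine(EMPTY, 4, EMPTY)
leaf6 = combine(EMPTY, 6, EMPTY)
print(combine(leaf4, 5, leaf6))  # Info(is_bst=True, total=15, lo=4, hi=6)
```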

[1]:
class TreeNode(object):
    def __init__(self, x):
        self.val = x
        self.left = None
        self.right = None

1. DFS solution with stack

[2]:
class Solution1(object):
    """
    Time: O(N)
    Space: O(H)
    """
    def maxSumBST(self, root):
        """
        :type root: TreeNode
        :rtype: int
        """
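        # Best sum seen so far; 0 covers the empty BST (Example 3).
        result = 0
        # Each frame is [node, child_slots, out]: child_slots is None on the
        # first visit, and out is the list this frame fills with
        # [is_bst, subtree_sum, min_key, max_key] for its parent.
        stk = [[root, None, []]]

        while stk:
            node, tmp, ret = stk.pop()
            if tmp:
                # Second visit: both children have written their results.
                lvalid, lsum, lmin, lmax = tmp[0]
                rvalid, rsum, rmin, rmax = tmp[1]
                if lvalid and rvalid and lmax < node.val < rmin:
                    total = lsum + node.val + rsum
                    result = max(result, total)
                    ret[:] = [True, total, min(lmin, node.val), max(node.val, rmax)]
                    continue
                # Not a BST; the min/max fields are never read by the parent.
                ret[:] = [False, 0, 0, 0]
                continue
            if not node:
                # Empty subtree: a BST whose bounds never reject the parent.
                ret[:] = [True, 0, float("inf"), float("-inf")]
                continue

            # First visit: re-push the node so it is handled after both
            # children (post-order); left is pushed last so it pops first.
            new_tmp = [[], []]
            stk.append([node, new_tmp, ret])
            stk.append([node.right, None, new_tmp[1]])
            stk.append([node.left, None, new_tmp[0]])

        return result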
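The stack frames above use an output-slot pattern: each frame carries a list (`ret`) that it fills in place with `ret[:] = ...`, and a parent is pushed back together with its children's slots so it is popped only after both have been filled. The same pattern, stripped down to computing tree height, may make the control flow easier to follow; `height` and the keyword-argument `TreeNode` are hypothetical, not part of the original solution. Because it never recurses, it also handles trees deeper than Python's default recursion limit.

```python
class TreeNode:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def height(root) -> int:
    """Iterative post-order with caller-owned output slots (empty tree -> 0)."""
    answer = []                     # slot for the root's result
    stack = [(root, None, answer)]  # (node, child slots or None, slot to fill)
    while stack:
        node, child_slots, out = stack.pop()
        if child_slots is not None:
            # Second visit: both children have already filled their slots.
            out[:] = [1 + max(child_slots[0][0], child_slots[1][0])]
        elif node is None:
            out[:] = [0]
        else:
            slots = ([], [])
            # Push the parent first so it is popped after both children.
            stack.append((node, slots, out))
            stack.append((node.right, None, slots[1]))
            stack.append((node.left, None, slots[0]))
    return answer[0]


print(height(TreeNode(1, TreeNode(2, TreeNode(3)))))  # 3
```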
[3]:
s = Solution1()

root = TreeNode(1)
root.left = TreeNode(4)
root.right = TreeNode(3)
root.left.left = TreeNode(2)
root.left.right = TreeNode(4)
root.right.left = TreeNode(2)
root.right.right = TreeNode(5)
root.right.right.left = TreeNode(4)
root.right.right.right = TreeNode(6)
assert s.maxSumBST(root) == 20

root = TreeNode(4)
root.left = TreeNode(3)
root.left.left = TreeNode(1)
root.left.right = TreeNode(2)
assert s.maxSumBST(root) == 2

root = TreeNode(-4)
root.left = TreeNode(-2)
root.right = TreeNode(-5)
assert s.maxSumBST(root) == 0

root = TreeNode(2)
root.left = TreeNode(1)
root.right = TreeNode(3)
assert s.maxSumBST(root) == 6

root = TreeNode(5)
root.left = TreeNode(4)
root.right = TreeNode(8)
root.left.left = TreeNode(3)
root.right.left = TreeNode(6)
root.right.right = TreeNode(3)
assert s.maxSumBST(root) == 7

2. DFS solution with recursion

[4]:
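class Solution2(object):
    """
    Time: O(N)
    Space: O(H) for the recursion stack. CPython's default recursion limit
    is 1000, so a skewed tree near the 40000-node bound exceeds it;
    Solution1 avoids this with an explicit stack.
    """
    def maxSumBST(self, root):
        """
        :type root: TreeNode
        :rtype: int
        """
        def dfs(node, result):
            # Returns (is_bst, subtree_sum, min_key, max_key) for node's subtree.
            if not node:
                # Empty subtree: a BST whose bounds never reject the parent.
                return True, 0, float("inf"), float("-inf")

            lvalid, lsum, lmin, lmax = dfs(node.left, result)

            rvalid, rsum, rmin, rmax = dfs(node.right, result)

            if lvalid and rvalid and lmax < node.val < rmin:
                total = lsum + node.val + rsum
                # result is a one-element list so nested calls can update it.
                result[0] = max(result[0], total)
                return True, total, min(lmin, node.val), max(node.val, rmax)

            # Not a BST; the min/max fields are never read by the parent.
            return False, 0, 0, 0

        result = [0]
        dfs(root, result)

        return result[0]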
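On Python 3, `nonlocal` lets the nested function rebind the running best directly instead of mutating a one-element list (the `result = [0]` idiom above, which also works on Python 2). A sketch of that variant with the same post-order logic, assuming a hypothetical function-style `max_sum_bst` and a keyword-argument `TreeNode`; it still recurses once per tree level, so deep trees remain subject to the recursion limit.

```python
import math


class TreeNode:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def max_sum_bst(root) -> int:
    """Same post-order DFS as Solution2, with the running best kept via nonlocal."""
    best = 0  # the empty BST always qualifies, so the answer is never negative

    def dfs(node):
        nonlocal best
        if node is None:
            return True, 0, math.inf, -math.inf
        lvalid, lsum, lmin, lmax = dfs(node.left)
        rvalid, rsum, rmin, rmax = dfs(node.right)
        if lvalid and rvalid and lmax < node.val < rmin:
            total = lsum + node.val + rsum
            best = max(best, total)
            return True, total, min(lmin, node.val), max(node.val, rmax)
        return False, 0, 0, 0  # bounds are ignored once a subtree is invalid

    dfs(root)
    return best


# Example 4: [2,1,3]
print(max_sum_bst(TreeNode(2, TreeNode(1), TreeNode(3))))  # 6
```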
[5]:
s = Solution2()

root = TreeNode(1)
root.left = TreeNode(4)
root.right = TreeNode(3)
root.left.left = TreeNode(2)
root.left.right = TreeNode(4)
root.right.left = TreeNode(2)
root.right.right = TreeNode(5)
root.right.right.left = TreeNode(4)
root.right.right.right = TreeNode(6)
assert s.maxSumBST(root) == 20

root = TreeNode(4)
root.left = TreeNode(3)
root.left.left = TreeNode(1)
root.left.right = TreeNode(2)
assert s.maxSumBST(root) == 2

root = TreeNode(-4)
root.left = TreeNode(-2)
root.right = TreeNode(-5)
assert s.maxSumBST(root) == 0

root = TreeNode(2)
root.left = TreeNode(1)
root.right = TreeNode(3)
assert s.maxSumBST(root) == 6

root = TreeNode(5)
root.left = TreeNode(4)
root.right = TreeNode(8)
root.left.left = TreeNode(3)
root.right.left = TreeNode(6)
root.right.right = TreeNode(3)
assert s.maxSumBST(root) == 7